#!/usr/bin/python
# Copyright 2020 Makani Technologies LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Generate the switch_def.h file."""

import os
import re
import sys
import textwrap

from makani.avionics.network import network_config


def _WriteSwitchDef(autogen_root, output_dir, script_name, switch_chips):
  """Write switch_def.h.

  For every entry in switch_chips the header gets NUM_SWITCH_PORTS_<TYPE> and
  NUM_SWITCH_FIBER_PORTS_<TYPE> macros (TYPE is the upper-cased chip type),
  followed by NUM_SWITCH_PORTS_MAX, the largest 'num_ports' seen (0 when
  switch_chips is empty).

  Args:
    autogen_root: The MAKANI_HOME-equivalent top directory. The header guard
        is derived from the output path relative to this directory.
    output_dir: The directory in which to output the files. It must already
        exist; this function does not create it.
    script_name: This script's filename, recorded in the "do not edit" banner.
    switch_chips: The 'switch_chips' field of the YAML file: an iterable of
        dicts, each with 'type', 'num_ports' and 'num_fiber_ports' keys.
  """
  file_without_extension = os.path.join(output_dir, 'switch_def')
  rel_path = os.path.relpath(file_without_extension, autogen_root)

  # Emit per-chip port counts while tracking the maximum across all chips.
  header_lines = []
  max_ports = 0
  for chip in switch_chips:
    chip_type = chip['type'].upper()
    header_lines.append('#define NUM_SWITCH_PORTS_{0} {1}'
                        .format(chip_type, chip['num_ports']))
    header_lines.append('#define NUM_SWITCH_FIBER_PORTS_{0} {1}'
                        .format(chip_type, chip['num_fiber_ports']))
    max_ports = max(max_ports, chip['num_ports'])
  header_lines.append('#define NUM_SWITCH_PORTS_MAX {0}'.format(max_ports))

  # Map every character that is not legal in a C identifier to '_'. This
  # covers path separators on any platform ('/' or '\'), '.', '-', etc., so
  # the guard is always a valid macro name.
  header_guard = re.sub(r'[^A-Z0-9_]', '_', rel_path.upper()) + '_H_'
  contents = textwrap.dedent("""
      #ifndef {guard}
      #define {guard}

      // Generated by {name}; do not edit.

      {header_defs}

      #endif  // {guard}
      """[1:]).format(guard=header_guard, name=script_name,
                      header_defs='\n'.join(header_lines))

  with open(file_without_extension + '.h', 'w') as f:
    f.write(contents)


def main(argv):
  """Entry point: load the network config and generate switch_def.h.

  Args:
    argv: The raw command line. It is passed to
        network_config.ParseGenerationFlags. The first element of the argument
        list that call returns has its basename used as the generating
        script's name.
  """
  flags, remaining_argv = network_config.ParseGenerationFlags(argv)
  net_config = network_config.NetworkConfig(flags.network_file)
  program = os.path.basename(remaining_argv[0])
  _WriteSwitchDef(autogen_root=flags.autogen_root,
                  output_dir=flags.output_dir,
                  script_name=program,
                  switch_chips=net_config.GetSwitchChips())

if __name__ == '__main__':
  # Invoked as a script: forward the complete command line to main().
  main(argv=sys.argv)
